#!/usr/bin/env python3
# -*- coding:utf-8 -*-
#############################################################################
# Copyright (c) 2020 Huawei Technologies Co.,Ltd.
#
# openGauss is licensed under Mulan PSL v2.
# You can use this software according to the terms
# and conditions of the Mulan PSL v2.
# You may obtain a copy of Mulan PSL v2 at:
#
#          http://license.coscl.org.cn/MulanPSL2
#
# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS,
# WITHOUT WARRANTIES OF ANY KIND,
# EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
# MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
# See the Mulan PSL v2 for more details.
# ----------------------------------------------------------------------------
# Description  : gs_expansion is a utility to expansion standby node databases
#############################################################################

import os
import sys
import subprocess
import socket
# Bootstrap: make sure LD_LIBRARY_PATH begins with "<script dir>/gspylib/clib"
# before the project modules below are imported.
# NOTE(review): presumably that directory holds bundled shared libraries
# needed by those imports - confirm.
package_path = os.path.dirname(os.path.realpath(__file__))
ld_path = package_path + "/gspylib/clib"
# Case 1: the variable is not set at all. Set it and replace the current
# process with a fresh run of this script (os.execve does not return).
if 'LD_LIBRARY_PATH' not in os.environ:
    os.environ['LD_LIBRARY_PATH'] = ld_path
    os.execve(os.path.realpath(__file__), sys.argv, os.environ)
# Case 2: the variable is set but does not start with ld_path. Prepend
# ld_path and re-execute. After either re-exec the value starts with ld_path,
# so the new process falls through both checks.
if not os.environ.get('LD_LIBRARY_PATH').startswith(ld_path):
    os.environ['LD_LIBRARY_PATH'] = \
        ld_path + ":" + os.environ['LD_LIBRARY_PATH']
    os.execve(os.path.realpath(__file__), sys.argv, os.environ)

# Append the first sys.path entry to the end of the module search path.
# NOTE(review): that entry is already on sys.path, so this looks redundant -
# confirm before removing.
sys.path.append(sys.path[0])
from gspylib.common.DbClusterInfo import dbClusterInfo, \
    checkPathVaild, dbNodeInfo, instanceInfo
from gspylib.common.GaussLog import GaussLog
from gspylib.common.Common import DefaultValue
from gspylib.common.ErrorCode import ErrorCode
from gspylib.common.ParallelBaseOM import ParallelBaseOM
from gspylib.common.ParameterParsecheck import Parameter
from impl.expansion.ExpansionImpl import ExpansionImpl
from impl.expansion.expansion_impl_with_cm import ExpansionImplWithCm
from impl.expansion.expansion_impl_with_cm_local import ExpansionImplWithCmLocal
from domain_utils.cluster_file.cluster_config_file import ClusterConfigFile
from domain_utils.cluster_file.cluster_log import ClusterLog
from base_utils.os.env_util import EnvUtil

# Environment variable names. Not referenced in the visible part of this
# file - presumably consumed by the expansion implementation modules; verify.
# NOTE(review): "PATH" and "LD_LIBRARY_PATH" are listed twice - confirm
# whether the duplicates are intentional.
ENV_LIST = ["MPPDB_ENV_SEPARATE_PATH", "GPHOME", "PATH",
            "LD_LIBRARY_PATH", "PYTHONPATH", "GAUSS_WARNING_TYPE",
            "GAUSSHOME", "PATH", "LD_LIBRARY_PATH",
            "S3_CLIENT_CRT_FILE", "GAUSS_VERSION", "PGHOST",
            "GS_CLUSTER_NAME", "GAUSSLOG", "GAUSS_ENV", "umask"]
# Attribute names that ExpansionClusterInfo.compare_dict() skips when it
# compares the static configuration against the XML configuration. They are
# skipped because the information stored for them in the openGauss static
# configuration file is incorrect.
# NOTE(review): "casecadeRole" looks like a misspelling of "cascadeRole"
# (the spelling used by IGNORE_CHECK_KEY below) - confirm.
ABORT_CHECK_PROPERTY = ["xmlFile", "id", "gtmNum", "instanceId",
                        "masterBasePorts", "standbyBasePorts",
                        "instanceType", "_dbClusterInfo__newInstanceId",
                        "_dbClusterInfo__newMirrorId", "version",
                        "installTime", "localNodeId", "nodeCount",
                        "_dbClusterInfo__newGroupId", "cmsNum", "datadir",
                        "enable_dcf", "dcf_config", "dcf_data_path", "cmscount", "casecadeRole"]
# Not referenced in the visible part of this file.
# NOTE(review): confirm whether another module uses it or it is dead.
IGNORE_CHECK_KEY = ["cascadeRole"]

class Expansion(ParallelBaseOM):
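    """
    Command object behind the gs_expansion script: it parses and validates
    the command line, loads the cluster description from the XML file, runs
    the pre-checks and then delegates the expansion to an implementation
    class.
    """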
    
    def __init__(self):
        """
        Create the command object with empty/default state.

        The attributes are filled in later by parseCommandLine() and
        getExpansionInfo().
        """
        ParallelBaseOM.__init__(self)

        # Options taken from the command line.
        self.standbyLocalMode = False   # set from -L
        self.time_out = None            # set from --time-out
        # Back IPs of the standby nodes to add (set from -h).
        self.newHostList = []

        # Lookup tables built by getExpansionInfo().
        self.clusterInfoDict = {}       # paths plus per-node settings
        self.backIpNameMap = {}         # back IP -> node name
        self.hostAzNameMap = {}         # first back IP -> AZ name
        self.newHostCasRoleMap = {}     # first back IP of new host -> cascadeRole

        # Directory two levels above the resolved path of this script file.
        script_file = os.path.realpath(__file__)
        self.packagepath = os.path.realpath(
            os.path.join(script_file, "../../"))

        # Value of MPPDB_ENV_SEPARATE_PATH as returned by EnvUtil.getEnv().
        self.envFile = EnvUtil.getEnv("MPPDB_ENV_SEPARATE_PATH")

    def usage(self):
        """
gs_expansion is a utility to expansion standby node for a cluster.

Usage:
    gs_expansion -? | --help
    gs_expansion -V | --version
    gs_expansion -U USER -G GROUP -X XMLFILE -h nodeList [-L] 
General options:
    -U                                 Cluster user.
    -G                                 Group of the cluster user.
    -X                                 Path of the XML configuration file.
    -h                                 New standby node backip list. 
                                       Separate multiple nodes with commas (,).
                                       such as '-h 192.168.0.1,192.168.0.2'
    -L                                 The standby database installed with 
                                       local mode.
        --time-out=SECS                Maximum waiting time when send the
                                       packages to new standby nodes.                         
    -?, --help                         Show help information for this
                                       utility, and exit the command line mode.
    -V, --version                      Show version information.
        """
        print(self.usage.__doc__)

    def parseCommandLine(self):
        """
        Read the gs_expansion command line and store the options on self.

        The options are obtained from Parameter.ParameterCommandLine(
        "expansion"). An option that was not supplied leaves the matching
        attribute untouched. When help was requested the usage text is
        printed and the process exits with status 0.
        """
        options = Parameter().ParameterCommandLine("expansion")

        # -? / --help: print the help text and stop.
        if "helpFlag" in options:
            self.usage()
            sys.exit(0)

        # -U: cluster user, passed through DefaultValue.checkPathVaild().
        if "user" in options:
            self.user = options.get("user")
            DefaultValue.checkPathVaild(self.user)

        # -G: group of the cluster user.
        if "group" in options:
            self.group = options.get("group")

        # -X: XML configuration file.
        if "confFile" in options:
            self.xmlFile = options.get("confFile")

        # -L: local mode; the same value is kept under two attribute names.
        if "localMode" in options:
            local_mode = options.get("localMode")
            self.localMode = local_mode
            self.standbyLocalMode = local_mode

        # -l: log file.
        if "logFile" in options:
            self.logFile = options.get("logFile")

        # -h: back IPs of the new standby nodes.
        if "nodename" in options:
            self.newHostList = options.get("nodename")

        # --time-out: validated later by checkParameters().
        if "time_out" in options:
            self.time_out = options.get("time_out")

    def checkParameters(self):
        """
        function: Check parameter from command line
        input: NA
        output: NA

        Checks performed, in order (each failure is reported through
        GaussLog.exitWithError):
          1. -U, -G, -X and -h must all be non-empty.
          2. DefaultValue.isUnderUpgrade(user) must be false.
          3. --time-out, when given, must consist of decimal digits only and
             satisfy 0 < value < 2147483647. When it was not given,
             DefaultValue.TIMEOUT_CLUSTER_START is used.
        """
        # check user | group | xmlfile | node
        # Attributes are read lazily, one at a time, in the listed order.
        mandatory = (("-U", "user"), ("-G", "group"),
                     ("-X", "xmlFile"), ("-h", "newHostList"))
        for option, attr_name in mandatory:
            if len(getattr(self, attr_name)) == 0:
                GaussLog.exitWithError(
                    ErrorCode.GAUSS_357["GAUSS_35701"] % option)
        # check if upgrade action is exist
        if DefaultValue.isUnderUpgrade(self.user):
            GaussLog.exitWithError(ErrorCode.GAUSS_529["GAUSS_52936"])

        if self.time_out is None:
            self.time_out = DefaultValue.TIMEOUT_CLUSTER_START
            return

        # The timeout parameter must be a pure number.
        # isdecimal() is used rather than isdigit(): isdigit() also accepts
        # characters such as superscript digits that int() cannot convert,
        # which would surface as an unhandled ValueError instead of the
        # parameter error below.
        timeout_text = str(self.time_out)
        if not timeout_text.isdecimal():
            GaussLog.exitWithError(
                ErrorCode.GAUSS_500["GAUSS_50003"] %
                ("-time-out", "a nonnegative integer"))
        self.time_out = int(timeout_text)
        # The timeout parameter must be greater than 0
        # The timeout parameter must be less than the integer maximum
        if self.time_out <= 0 or self.time_out >= 2147483647:
            GaussLog.exitWithError(ErrorCode.GAUSS_500["GAUSS_50004"]
                                   % "-time-out")

    def _get_node_dn_port(self, node_name):
        """
        Get data node port

        input : node_name - name of a node known to self.clusterInfo
        output: port of the first data node instance on that node, or None
                when the node has no data node instances
        raises: Exception when the cluster type is not
                DefaultValue.CLUSTER_TYPE_SINGLE_INST
        """
        if self.clusterInfo.clusterType != \
                DefaultValue.CLUSTER_TYPE_SINGLE_INST:
            self.logger.log("The cluster type is not single-inst.")
            # A real exception type is raised here. The previous code raised
            # an Expansion instance, whose constructor accepts no arguments,
            # so this branch could only ever fail with a TypeError.
            raise Exception("Cluster type is not single-inst.")
        node = self.clusterInfo.getDbNodeByName(node_name)
        if node.datanodes:
            return node.datanodes[0].port
        return None

    def _getClusterInfoDict(self):
        """
        Parse the XML configuration and cache the result on self.

        Sets self.clusterInfo, self.nodeNameList and self.clusterInfoDict.
        The dictionary holds the keys "appPath", "logPath", "corePath" and
        "toolPath" plus one entry per node name with that node's back IP,
        ssh IP, data node port, the ports derived from it, the first data
        directory, the data node instance type and the AZ priority.

        raises: Exception when a node has no data node instance, because the
                derived ports cannot be computed without a port; any
                exception raised by check_xml_config() is propagated.
        """
        self.check_xml_config()
        clusterInfo = ExpansionClusterInfo()
        self.clusterInfo = clusterInfo
        hostNameIpDict = clusterInfo.initFromXml(self.xmlFile)
        clusterDict = clusterInfo.get_cluster_directory_dict()
        self.nodeNameList = clusterInfo.getClusterNodeNames()

        # get corepath and toolpath from xml file
        corePath = clusterInfo.readClustercorePath(self.xmlFile)
        toolPath = clusterInfo.getToolPath(self.xmlFile)
        # parse xml file and cache node info
        clusterInfoDict = {
            "appPath": clusterDict["appPath"][0],
            "logPath": clusterDict["logPath"][0],
            "corePath": corePath,
            "toolPath": toolPath,
        }
        for nodeName in self.nodeNameList:
            # First element of the host entry is the IP list:
            # index 0 is used as back IP, index 1 as ssh IP.
            ipList = hostNameIpDict[nodeName][0]
            port = self._get_node_dn_port(nodeName)
            if port is None:
                # Without this guard int(None) below fails with an opaque
                # TypeError that does not name the offending node.
                raise Exception(
                    "Failed to obtain the data node port of node {0}: "
                    "no data node instance is configured on it.".format(
                        nodeName))
            basePort = int(port)
            dataDirs = clusterDict[nodeName]["dn"]["data_dir"]
            dbNode = clusterInfo.getDbNodeByName(nodeName)
            clusterInfoDict[nodeName] = {
                "backIp": ipList[0],
                "sshIp": ipList[1],
                "port": port,
                "localport": basePort + 1,
                "localservice": basePort + 4,
                "heartBeatPort": basePort + 3,
                "dataNode": dataDirs[0] if dataDirs else "",
                # placeholder, replaced in the loop below
                "instanceType": -1,
                "azPriority": dbNode.azPriority
            }

        # Fill in the real data node instance type of every node.
        for nodeId in clusterInfo.getClusterNodeIds():
            insType = clusterInfo.getdataNodeInstanceType(nodeId)
            hostName = clusterInfo.getHostNameByNodeId(nodeId)
            clusterInfoDict[hostName]["instanceType"] = insType
        self.clusterInfoDict = clusterInfoDict

    def initLogs(self):
        """
        init log file

        Resolves the log file path (default from ClusterLog.getOMLogPath when
        -l was not given), requires it to be absolute, initialises the logger
        and hands the log file over to the cluster user.

        The ownership change is best effort: a failure is written to the
        debug log and otherwise ignored.
        """
        # Local imports: only this method needs these stdlib modules.
        import pwd
        import shutil

        # if no log file
        if self.logFile == "":
            self.logFile = ClusterLog.getOMLogPath(
                DefaultValue.EXPANSION_LOG_FILE, self.user, "",
                self.xmlFile)
        # if not absolute path
        if not os.path.isabs(self.logFile):
            GaussLog.exitWithError(ErrorCode.GAUSS_502["GAUSS_50213"] % "log")

        self.initLogger("gs_expansion")
        # change the owner of gs_expansion.log to the db user
        logPath = self.logger.logFile
        if os.path.isfile(logPath):
            # Inspect the actual log file. The previous code listed
            # "gs_expansion*" in the current working directory, which is
            # unrelated to where the log file lives.
            try:
                owner = pwd.getpwuid(os.stat(logPath).st_uid).pw_name
            except (KeyError, OSError):
                # uid without a passwd entry, or the stat failed
                owner = None
            if owner != self.user:
                try:
                    # No shell involved, so user/group/path need no quoting.
                    shutil.chown(logPath, user=self.user, group=self.group)
                except (OSError, LookupError) as err:
                    self.logger.debug(
                        "Failed to change the owner of %s: %s"
                        % (logPath, str(err)))
        self.logger.ignoreErr = True
    
    def getExpansionInfo(self):
        self._getClusterInfoDict()
        self._getBackIpNameMap()
        self._getHostAzNameMap()
        self._getNewHostCasRoleMap()

    def checkXmlIncludeNewHost(self):
        """
        check parameter node must in xml config file
        """
        backIpList = self.clusterInfo.getClusterBackIps()
        for nodeIp in self.newHostList:
            if nodeIp not in backIpList:
                GaussLog.exitWithError(ErrorCode.GAUSS_357["GAUSS_35702"] %
                    nodeIp)

    def _getBackIpNameMap(self):
        backIpList = self.clusterInfo.getClusterBackIps()
        for backip in backIpList:
            self.backIpNameMap[backip] = \
                self.clusterInfo.getNodeNameByBackIp(backip)

    def checkExecutingUser(self):
        """
        Require that the command is executed by root (uid 0).

        Any other uid is reported through GaussLog.exitWithError.
        """
        runningAsRoot = (os.getuid() == 0)
        if runningAsRoot:
            return
        GaussLog.exitWithError(ErrorCode.GAUSS_501["GAUSS_50104"])

    def checkExecutingHost(self):
        """
        check whether current host is primary host
        """
        currentHost = socket.gethostname()
        primaryHost = ""
        for nodeName in self.nodeNameList:
            if self.clusterInfoDict[nodeName]["instanceType"] \
                    == 0:
                primaryHost = nodeName
                break
        if currentHost != primaryHost:
            GaussLog.exitWithError(ErrorCode.GAUSS_501["GAUSS_50110"] %
                (currentHost + ", which is not primary"))

    def checkTrust(self, hostList=None):
        """
        check trust between primary/current host and every host in hostList

        input : hostList - hosts to probe; when None, every node name plus
                every back IP of the cluster is probed
        output: NA

        For each host the bundled pssh script runs 'pwd' on it twice: once
        as the current user and once via "su - <cluster user>". Hosts for
        which either probe returns a non-zero status are collected and
        reported together through GaussLog.exitWithError.

        NOTE(review): both commands are shell strings built from the host
        names and the cluster user - confirm these values are validated
        upstream.
        """
        # "is None" rather than "== None": identity is the correct test for
        # the None singleton and cannot be affected by a custom __eq__.
        if hostList is None:
            hostList = list(self.nodeNameList)
            hostList.extend(self.clusterInfo.getClusterBackIps())
        gpHome = EnvUtil.getEnv("GPHOME")
        psshPath = "python3 %s/script/gspylib/pssh/bin/pssh" % gpHome
        rootSSHExceptionHosts = []
        individualSSHExceptionHosts = []
        for host in hostList:
            # check root's trust
            checkRootTrustCmd = "%s -s -H %s 'pwd'" % (psshPath, host)
            status = subprocess.getstatusoutput(checkRootTrustCmd)[0]
            if status != 0:
                rootSSHExceptionHosts.append(host)
            # check individual user's trust
            checkUserTrustCmd = "su - %s -c '%s -s -H %s pwd'" % (
                self.user, psshPath, host)
            status = subprocess.getstatusoutput(checkUserTrustCmd)[0]
            if status != 0:
                individualSSHExceptionHosts.append(host)

        # output ssh exception info if ssh connect failed
        failureParts = []
        if rootSSHExceptionHosts:
            failureParts.append(
                "\n" + ", ".join(rootSSHExceptionHosts) + " by root")
        if individualSSHExceptionHosts:
            failureParts.append(
                "\n" + ", ".join(individualSSHExceptionHosts)
                + " by individual user")
        if failureParts:
            GaussLog.exitWithError(ErrorCode.GAUSS_511["GAUSS_51100"] %
                "".join(failureParts))

    def checkEnvfile(self):
        """
        Check that the environment is loaded and agrees with the XML file.

        GPHOME, GAUSSHOME and PGHOST must have a value according to
        EnvUtil.getEnv; GPHOME must equal the cached "toolPath" and
        GAUSSHOME the cached "appPath". Each violation is reported through
        GaussLog.exitWithError.
        """
        self.logger.debug("Checking environment variable.")
        requiredEnv = (
            ("GPHOME", "\"GPHOME\", please import environment variable"),
            ("GAUSSHOME", "\"GAUSSHOME\", please import environment variable"),
            ("PGHOST", "\"PGHOST\", please import environment variable"),
        )
        for envName, hint in requiredEnv:
            if EnvUtil.getEnv(envName):
                continue
            GaussLog.exitWithError(ErrorCode.GAUSS_518["GAUSS_51802"] % (
                hint))

        expected = self.clusterInfoDict
        # Both values are read before either comparison is made.
        envPaths = (
            ("toolPath", EnvUtil.getEnv("GPHOME")),
            ("appPath", EnvUtil.getEnv("GAUSSHOME")),
        )
        for key, envValue in envPaths:
            if envValue != expected[key]:
                GaussLog.exitWithError(
                    ErrorCode.GAUSS_357["GAUSS_35711"] % key)

    def _getHostAzNameMap(self):
        """
        get azName of all hosts
        """
        for dbnode in self.clusterInfo.dbNodes:
            self.hostAzNameMap[dbnode.backIps[0]] = dbnode.azName
    
    def _getNewHostCasRoleMap(self):
        """
        get cascadeRole of newHosts
        """
        for dbnode in self.clusterInfo.dbNodes:
            if dbnode.backIps[0] in self.newHostList:
                self.newHostCasRoleMap[dbnode.backIps[0]] = dbnode.cascadeRole

    def check_cm_component(self):
        """
        Tell whether the existing cluster has a CM server.

        The static configuration of self.user is loaded on every call and
        the CM server count is taken from
        DefaultValue.get_cm_server_num_from_static().

        output: True when that count is greater than 0, otherwise False
        """
        staticInfo = dbClusterInfo()
        staticInfo.initFromStaticConfig(self.user)
        cmServerNum = DefaultValue.get_cm_server_num_from_static(staticInfo)
        return True if cmServerNum > 0 else False

    def check_xml_config(self):
        """
        Compare the XML configuration with the static configuration.

        The XML file is loaded into a plain dbClusterInfo, the static
        configuration of self.user into another one, and both are handed to
        ExpansionClusterInfo.compare_cluster_info(); exceptions raised by the
        comparison propagate to the caller.
        """
        # The comparer is itself initialised from the XML file first.
        comparer = ExpansionClusterInfo()
        comparer.initFromXml(self.xmlFile)

        fromXml = dbClusterInfo()
        fromXml.initFromXml(self.xmlFile)

        fromStatic = dbClusterInfo()
        fromStatic.initFromStaticConfig(self.user)

        comparer.compare_cluster_info(fromStatic, fromXml)

    def expand_run(self):
        """
        This is expansion frame start

        Chooses the implementation class and runs it:
          - CM present and -L given : ExpansionImplWithCmLocal
          - CM present              : ExpansionImplWithCm
          - otherwise               : ExpansionImpl
        """
        # check_cm_component() re-reads the static configuration on every
        # call, so it is evaluated once here instead of once per branch.
        hasCm = self.check_cm_component()
        if hasCm and self.standbyLocalMode:
            expand_impl = ExpansionImplWithCmLocal(self)
            self.logger.log("Start expansion with cluster manager component on standby node.")
        elif hasCm:
            expand_impl = ExpansionImplWithCm(self)
            self.logger.log("Start expansion with cluster manager component.")
        else:
            expand_impl = ExpansionImpl(self)
            self.logger.log("Start expansion without cluster manager component.")
        expand_impl.run()

class ExpansionClusterInfo(dbClusterInfo):

    def __init__(self):
        """
        Initialise the base cluster info and the comparison breadcrumb list.
        """
        dbClusterInfo.__init__(self)
        # Labels of the items currently being compared; see
        # _add_alarm_msg() and _remove_normal_msg().
        self.collect_compare_result = []

    def _remove_normal_msg(self, key):
        """
        Remove normal message in alarm_message
        """
        if key in self.collect_compare_result:
            self.collect_compare_result.remove(key)

    def _add_alarm_msg(self, key):
        """
        Remove normal message in alarm_message
        """
        self.collect_compare_result.append(key)
    
    def getToolPath(self, xmlFile):
        """
        function : Read the "gaussdbToolPath" cluster item from the XML file
        input : String - path of the XML configuration file
        output : String - normalised tool path
        raises : Exception when the item cannot be read (non-zero status)
        """
        ClusterConfigFile.setDefaultXmlFile(xmlFile)
        # read gaussdb tool path from xml file
        parsedXml = ClusterConfigFile.initParserXMLFile(xmlFile)
        status, value = ClusterConfigFile.readOneClusterConfigItem(
            parsedXml, "gaussdbToolPath", "cluster")
        if status != 0:
            # on failure the returned value is appended as the error detail
            reason = ErrorCode.GAUSS_512["GAUSS_51200"] % "gaussdbToolPath"
            detail = " Error: \n%s" % value
            raise Exception(reason + detail)
        normalized = os.path.normpath(value)
        checkPathVaild(normalized)
        return normalized

    def find_right_node(self, a_node, b_list):
        """
        Find node in b_list
        """
        for node in b_list:
            if node.name == a_node.name:
                return node
        raise Exception("Node {0} not config in XML.".format(a_node.name))

    def find_right_instance(self, a_inst, b_list):
        """
        Find instance in b_list
        """
        for inst in b_list:
            if inst.instanceId == a_inst.instanceId:
                return inst
            elif inst.instanceRole in [DefaultValue.INSTANCE_ROLE_CMSERVER,
                                       DefaultValue.INSTANCE_ROLE_CMAGENT]:
                return inst
        raise Exception("Instance {0} not config in XML.".format(inst.instanceId))

    def compare_list(self, a_list, b_list):
        """
        Compare every element of a_list with its counterpart in b_list.

        dbNodeInfo elements are matched by name and instanceInfo elements
        through find_right_instance(); their attribute dictionaries are then
        compared. dict, list, str and int elements are compared with the
        element at the same index of b_list. Elements of any other type are
        skipped. While an element is compared its label is kept in
        collect_compare_result.
        """
        # Checked in this order, so an element uses the first matching type.
        positional = (
            (dict, self.compare_dict),
            (list, self.compare_list),
            (str, self.compare_string),
            (int, self.compare_int),
        )
        for index, a_item in enumerate(a_list):
            if isinstance(a_item, dbNodeInfo):
                label = "node: {0}".format(a_item.name)
                self._add_alarm_msg(label)
                b_node = self.find_right_node(a_item, b_list)
                self.compare_dict(a_item.__dict__, b_node.__dict__)
                self._remove_normal_msg(label)
                continue
            if isinstance(a_item, instanceInfo):
                label = "instance: {0}".format(str(a_item.instanceId))
                self._add_alarm_msg(label)
                b_inst = self.find_right_instance(a_item, b_list)
                self.compare_dict(a_item.__dict__, b_inst.__dict__)
                self._remove_normal_msg(label)
                continue
            for kind, compare in positional:
                if isinstance(a_item, kind):
                    # the element itself serves as its label
                    self._add_alarm_msg(a_item)
                    compare(a_item, b_list[index])
                    self._remove_normal_msg(a_item)
                    break

    def compare_string(self, a_string, b_string):
        """
        Compare list object
        """
        if a_string != b_string:
            raise Exception((ErrorCode.GAUSS_357["GAUSS_35711"] %
                             self.collect_compare_result[-1]) +
                            "XML configure string item failed: {0} . "
                            "Static config {1}. "
                            "XML config {2}".format(self.collect_compare_result,
                                                    a_string, b_string))

    def compare_int(self, a_integer, b_integer):
        """
        Compare list object
        """
        if a_integer != b_integer:
            raise Exception((ErrorCode.GAUSS_357["GAUSS_35711"] %
                             self.collect_compare_result[-1]) +
                            "XML configure integer item failed. : {0} . "
                            "Static config {1}. "
                            "XML config {2}".format(self.collect_compare_result,
                                                    a_integer, b_integer))

    def compare_dict(self, a_dict, b_dict):
        """
        Compare dict object

        input : a_dict - attributes from the static configuration
                b_dict - attributes from the XML configuration
        raises: Exception when the two values of a key have different types
                or a value has an unsupported type; exceptions from the
                nested dict/list comparison propagate

        Keys listed in ABORT_CHECK_PROPERTY are skipped. dict and list
        values are compared recursively. str and int values only have their
        type checked; their contents are not compared.
        NOTE(review): confirm that leaving scalar values uncompared is
        intended.
        """
        for a_key, a_value in a_dict.items():
            if a_key in ABORT_CHECK_PROPERTY:
                self._remove_normal_msg(a_key)
                continue
            b_value = b_dict.get(a_key)
            if type(a_value) is not type(b_value):
                raise Exception("The value type of the XML configuration item [{0}] "
                                "is inconsistent with that "
                                "in the static configuration.".format(a_key))
            self._add_alarm_msg(a_key)
            if isinstance(a_value, dict):
                self.compare_dict(a_value, b_value)
            elif isinstance(a_value, list):
                self.compare_list(a_value, b_value)
            elif isinstance(a_value, (str, int)):
                # type already verified above; the value is not compared
                pass
            else:
                raise Exception("Not support type. key is {0} , "
                                "static value is {1}, XML value is {2}, "
                                "type is {3}".format(a_key, a_value,
                                                     b_value, type(a_value)))
            # The key passed, so take it off the breadcrumb list. This now
            # also happens for str/int values. They used to stay behind and
            # show up as stale entries in later error messages.
            self._remove_normal_msg(a_key)

    def compare_cluster_info(self, static_cluster_info, xml_cluster_info):
        """
        Compare cluster_info
        """
        if len(static_cluster_info.dbNodes) >= len(xml_cluster_info.dbNodes):
            raise Exception("XML configuration failed, "
                            "node count in expantion XML must be more than static file.")
        self.compare_dict(static_cluster_info.__dict__, xml_cluster_info.__dict__)


# Script entry point. The steps below depend on each other and must run in
# this order.
if __name__ == "__main__":
    """
    """
    expansion = Expansion()
    # The tool must be started by root (uid 0).
    expansion.checkExecutingUser()
    # Read the command line, then validate the options.
    expansion.parseCommandLine()
    expansion.checkParameters()
    # Logging is set up before the steps that write to self.logger.
    expansion.initLogs()
    # Parse the XML file and build the cached lookup tables; this also sets
    # expansion.clusterInfo, which the checks below read.
    expansion.getExpansionInfo()
    # Environment variables must be loaded and agree with the XML paths.
    expansion.checkEnvfile()
    # Every host given with -h must be described in the XML file.
    expansion.checkXmlIncludeNewHost()
    # The command must run on the primary host.
    expansion.checkExecutingHost()
    # Probe every node name and back IP through pssh, both directly and via
    # "su" to the cluster user.
    expansion.checkTrust()
    # Select the implementation and perform the expansion.
    expansion.expand_run()
